# Copyright 2020-2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Parameter for cell."""
from __future__ import absolute_import

from copy import copy

import time
import os
import sys
import numbers
import numpy as np

from mindspore import log as logger
from mindspore.log import _LogActionOnce
from mindspore._c_expression import ParamInfo
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor, _TensorMeta
from mindspore.common.hook_handle import _update_hook_version
from mindspore import _checkparam as Validator
from mindspore._check_jit_forbidden_api import jit_forbidden_register
from mindspore._c_expression import TensorPy as Tensor_
from mindspore.parallel._tensor import _get_slice_index
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication._comm_helper import _is_initialized
from mindspore.communication import get_group_size, get_rank
import mindspore.common._monad as monad

__all__ = ['Parameter', 'ParameterTuple']

PARAMETER_NAME_DEFAULT = "Parameter"
_GENERATED_PARAMETER_NAME_PREFIX = PARAMETER_NAME_DEFAULT + '#'
PARAMETER_NAME_PREFIX_MAX_LEN = 1024

_PARAMETER_NAME_ID = 0


def _generate_parameter_name():
    """Generate a unique default parameter name with the `Parameter#` prefix."""
    global _PARAMETER_NAME_ID
    name = _GENERATED_PARAMETER_NAME_PREFIX + str(_PARAMETER_NAME_ID)
    _PARAMETER_NAME_ID += 1
    return name


def _is_parameter_generated(param_name):
    """Check whether `param_name` was auto-generated by `_generate_parameter_name`."""
    if not param_name or not isinstance(param_name, str):
        return False
    return param_name.startswith(_GENERATED_PARAMETER_NAME_PREFIX)


# Global variable for parameter unique key.
_GLOBAL_PARAMETER_KEY = -1


def _is_in_auto_parallel_mode():
    """Get parallel mode."""
    return auto_parallel_context().get_parallel_mode() in ["semi_auto_parallel", "auto_parallel"]


def _is_parallel_mode():
    """ Whether is parallel mode """
    if not _is_initialized():
        return False
    if os.getenv("RUN_MODE") != "predict":
        return False
    if get_group_size() > 1:
        return True
    return False


def init_to_value(init):
    """
    Get the numeric value of an initializer.

    Args:
        init (Union[str, numbers.Number]): The initializer, either one of the string aliases
            ``'zeros'`` and ``'ones'`` or a number.

    Returns:
        Float, the value of the initializer.

    Raises:
        ValueError: If `init` is neither a supported string alias nor a number.
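
    Examples:
        >>> init_to_value('zeros')
        0.0
        >>> init_to_value(2)
        2.0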
    """
    if isinstance(init, str):
        if init == 'zeros':
            return 0.0
        if init == 'ones':
            return 1.0
        raise ValueError("The argument 'init' should be one of values in ['zeros', 'ones'].")
    if isinstance(init, numbers.Number):
        return float(init)
    raise ValueError("The argument 'init' should be number or string, but got {}.".format(type(init)))


def _get_unique_parameter_key():
    """
    Get parameter unique key.
    Used to identify the same Parameter for Worker and Server in the embedding cache scenario.

    Returns:
        Integer. The unique parameter key.
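
    Examples:
        >>> # The key comes from a module-level counter, so consecutive calls differ by one.
        >>> key1 = _get_unique_parameter_key()
        >>> key2 = _get_unique_parameter_key()
        >>> key2 == key1 + 1
        True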
    """
    global _GLOBAL_PARAMETER_KEY
    _GLOBAL_PARAMETER_KEY += 1
    return _GLOBAL_PARAMETER_KEY


def _gen_offload_file_path(offload_dir):
    """Generate a unique offload file path like `<offload_dir>/<rank>_<unique key>_<timestamp>.data`."""
    offload_dir = os.path.relpath(offload_dir)
    if not os.path.exists(offload_dir):
        os.makedirs(offload_dir, mode=0o700, exist_ok=True)
    file_name = str(get_rank()) + "_" + str(_get_unique_parameter_key()) + "_" + str(time.time()) + ".data"
    return os.path.join(offload_dir, file_name)


def _offload_if_config(data):
    """
    Offload parameter data (larger than 512 bytes) to a file when memory offload to disk is enabled.
    Args:
        data: The parameter data to offload.
    """
    if data is None:
        return

    data_size_threshold = 512
    if data.nbytes < data_size_threshold:
        return

    offload_file_path = data.offload_file_path()
    if offload_file_path is None or offload_file_path == "":
        # Assumption: the offload configuration dict is obtained from
        # mindspore.context.get_offload_context(); its "offload_path" entry is the target directory.
        from mindspore import context
        offload_context = context.get_offload_context()
        offload_dir = offload_context.get("offload_path", "./offload")
        offload_file_path = _gen_offload_file_path(offload_dir)
    data.offload(offload_file_path)


class Parameter(Tensor_):
    """
    `Parameter` is a `Tensor` subclass. When a `Parameter` is assigned as a Cell attribute, it is
    automatically added to the Cell's parameter list and will appear, e.g., in the
    `cell.get_parameters()` iterator.

    Note:
        - When using `AutoParallel(cell)` to enable parallel mode, if a `Parameter` is initialized by a
          `Tensor`, the type of the Parameter will be `Tensor`. The `Tensor` only saves the shape and type
          info of the tensor, with no memory usage.

        - The shape can be changed while compiling for auto-parallel. Calling `init_data` will return a
          `Parameter` with initialized data.

        - If there is an operator in the network that requires part of the inputs to be Parameter,
          then the Parameters as this part of the inputs are not allowed to be cast.

        - Give each `Parameter` a unique name to facilitate subsequent operations and updates.
          If two or more `Parameter` objects with the same name exist in a network,
          you will be prompted to set a unique name when defining them.

        - When directly printing a `Parameter`, you cannot view the actual values contained inside it.
          You need to use the `Parameter.asnumpy()` method to access the actual values.

    Args:
        default_input (Union[Tensor, int, float, numpy.ndarray, list]): Parameter data,
            to initialize the parameter data.
        name (str): Name of the parameter. Default: ``None`` . If two or more `Parameter`
            objects with the same name exist in a network,
            you will be prompted to set a unique name when defining them.

            1) If the parameter is not given a name, the default name is its variable name. For example, the name of
            param_a below is name_a, and the name of param_b is the variable name param_b.

            .. code-block::

                self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
                self.param_b = Parameter(Tensor([2], ms.float32))

            2) If a parameter in a list or tuple is not given a name, a unique name is generated for it.
            Generated names use the ``Parameter#`` prefix followed by a running index, e.g. the names of the
            parameters below are **Parameter#0** and **Parameter#1**.

            .. code-block::

                self.param_list = [Parameter(Tensor([3], ms.float32)),
                                   Parameter(Tensor([4], ms.float32))]

            3) If a parameter is given a name that another parameter already uses, an exception will be
            thrown. For example, "its name 'name_a' already exists." will be thrown.

            .. code-block::

                self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
                self.param_tuple = (Parameter(Tensor([5], ms.float32), name="name_a"),
                                    Parameter(Tensor([6], ms.float32)))

            4) If a parameter appears multiple times in a list or tuple, its name is checked only once. For
            example, the following example will not throw an exception.

            .. code-block::

                self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
                self.param_tuple = (self.param_a, self.param_a)

        requires_grad (bool): It is used to filter parameters in :func:`mindspore.nn.Cell.trainable_params()`.
            If it is ``False``, the parameter will not be returned by
            :func:`mindspore.nn.Cell.trainable_params()`.
            Default: ``True`` .
        layerwise_parallel (bool): When `layerwise_parallel` is ``True`` in data/hybrid parallel mode,
            broadcast and gradient communication are not applied to the `Parameter`. Default: ``False`` .
        parallel_optimizer (bool): It is used to filter the weight shard operation in parallel mode. It works only
            when the parallel optimizer is enabled via :func:`mindspore.parallel.auto_parallel.AutoParallel.hsdp`.
            Default: ``True`` .
        storage_format (str): Only Ascend device target is supported. It is used to specify the format of the weight
            loaded to the device. By default, the format is not changed. The optional values are ``"FRACTAL_NZ"`` ,
            ``"NC1HWC0"`` , ``"FRACTAL_Z"`` , etc. Default: ``""`` .
        device (str): Only Ascend device target is supported. It is used to specify the device on which the
            parameter is stored. The optional values are ``"CPU"`` and ``"Remote"``. By default, the parameter is
            stored on the NPU while computing. When the device is specified as ``"CPU"``, the parameter is loaded
            into the device when it needs to be used, and unloaded to the CPU after use. It takes effect only when
            `jit_level` is not ``"O2"`` and `memory_optimize_level` is ``O0`` in :func:`mindspore.set_context`. Less
            device memory is needed when the device is specified as ``"CPU"``. A parameter with the ``"Remote"``
            device type is stored remotely and loaded to the device when needed. The ``"Remote"`` device type is an
            experimental option.
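
            For example, the following sketch (Ascend target assumed) keeps a large weight on the CPU
            and loads it to the device only when it is used:

            .. code-block::

                self.big_weight = Parameter(Tensor(np.ones((1024, 1024)), ms.float32),
                                            name="big_w", device="CPU")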

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Parameter, Tensor, ops, nn
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.matmul = ops.MatMul()
        ...         self.weight = Parameter(Tensor(np.ones((1, 2)), mindspore.float32), name="w", requires_grad=True)
        ...
        ...     def construct(self, x):
        ...         out = self.matmul(self.weight, x)
        ...         return out
        >>> net = Net()
        >>> x = Tensor(np.ones((2, 1)), mindspore.float32)
        >>> print(net(x))
        [[2.]]
        >>> net.weight.set_data(Tensor(np.zeros((1, 2)), mindspore.float32))
        >>> print(net(x))
        [[0.]]
    """
    _base_type = {}

    def __new__(cls, default_input, *args, **kwargs):
        init_data_flag = bool(isinstance(default_input, Tensor) and default_input.has_init)
        rc = sys.getrefcount(default_input)
        init_param = getattr(cls, "init_param", True)
        # pylint: disable-msg=C0123
        if isinstance(default_input, Tensor) and not isinstance(default_input, Parameter) \
                and type(default_input) != Tensor:
            input_class = type(default_input)
            new_type = Parameter._get_combined_class(input_class)
            obj = input_class.__new__(new_type, default_input)
            obj.init_mode = None
            obj.is_default_input_init = init_data_flag
            if obj.has_init:
                obj.init_mode = default_input
            return obj

        input_class, *class_init_args = Parameter._get_parameter_new_args(default_input, rc, init_param)
        new_type = Parameter._get_base_class(input_class)
        obj = input_class.__new__(new_type)
        input_class.__init__(obj, *class_init_args)
        obj.init_mode = None
        obj.is_default_input_init = init_data_flag
        if obj.has_init:
            obj.init_mode = default_input
        return obj

    def __reduce_ex__(self, _):
        if self.init_mode is not None:
            data = self.init_mode
        else:
            # cast to a plain Tensor to break infinite recursion during deepcopy
            data = Tensor(self)
        return (
            Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel))

    def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False, parallel_optimizer=True,
                 storage_format="", device=None):
        self.param_info = ParamInfo()
        self.name = name
        self.requires_grad = requires_grad
        self.layerwise_parallel = layerwise_parallel
        self.parallel_optimizer = parallel_optimizer
        # This flag is for tensor copy data.
        self.init_flag = False
        # This flag is for GE variable copy data.
        self.is_init = False
        self._inited_param = None
        self._sliced = False
        self.requires_aggr = True
        self._cast_type = None
        self._unique = False
        self.is_in_parallel = _is_in_auto_parallel_mode()
        self._pipeline_stage_list = []
        if -1 in self.shape:
            raise ValueError(f"All shape elements of the Parameter must be positive, but got shape {self.shape}.")
        if isinstance(default_input, (Tensor_, Tensor)):
            Tensor_.__init__(self, dtype=default_input.dtype, shape=default_input.shape)
        elif isinstance(default_input, int):
            Tensor_.__init__(self, dtype=mstype.int64, shape=())
        elif isinstance(default_input, float):
            Tensor_.__init__(self, dtype=mstype.float32, shape=())
        elif isinstance(default_input, (np.ndarray, list)):
            Tensor_.__init__(self, default_input)
        else:
            raise TypeError(f"The type of the argument 'default_input' must be in ['Tensor', 'int', 'float',"
                            f" 'numpy.ndarray', 'list']. But got type {type(default_input)}.")
        self.param_info.parameter_shape = self.shape
        self.param_info.storage_format = storage_format
        if device is not None:
            if device in ("CPU", "Remote"):
                self._set_user_data("parameter_device", device)
            else:
                raise ValueError(f"Only 'CPU' and 'Remote' are supported for device, but got {device}.")

        from mindspore.ops.operations import other_ops
        self.load = other_ops.Load()

    def __deepcopy__(self, memodict):
        new_obj = Parameter(self, requires_grad=self.requires_grad)
        new_obj.name = self.name
        new_obj._inited_param = self._inited_param
        return new_obj

    def __str__(self):
        return f'Parameter (name={self.name}, shape={self.shape}, dtype={self.dtype}, ' \
               f'requires_grad={self.requires_grad})'

    def __repr__(self):
        return self.__str__()

    def __parameter__(self):
        """For parse check."""

    @staticmethod
    def _get_base_class(input_class):
        input_class_name = Parameter.__name__
        if input_class_name in Parameter._base_type:
            new_type = Parameter._base_type.get(input_class_name)
        else:
            new_type = type(input_class_name, (Parameter, input_class), {})
            Parameter._base_type[input_class_name] = new_type
        return new_type

    @staticmethod
    def _get_combined_class(tensor_subclass):
        """Create sub class of Parameter and tensor_subclass"""
        class_name = f"Parameter{tensor_subclass.__name__}"

        if class_name in Parameter._base_type:
            return Parameter._base_type[class_name]

        def new_init(self, default_input, *args, **kwargs):
            Parameter.__init__(self, default_input, *args, **kwargs)
            default_input.local_param_info = self.param_info

        new_type = type(class_name, (tensor_subclass, Parameter), {'__init__': new_init})
        Parameter._base_type[class_name] = new_type
        return new_type

    @staticmethod
    def _get_parameter_new_args(data, rc, init_param=True):
        """Set `set_data` of current `Parameter`."""
        if isinstance(data, bool):
            raise ValueError('Parameter data can not be `bool`')
        if isinstance(data, Tensor):
            if not data.has_init:
                if rc == 4:
                    # A ref count of 4 means the input data is not referenced
                    # elsewhere, so we can make a Tensor without copying the data.
                    return (Tensor, data)
                    return (Tensor, data)
                # make a copy of Tensor to init the parameter.
                if data.dtype == mstype.qint4x2:
                    return (Tensor, data.asnumpy(), mstype.qint4x2)
                return (Tensor, data.asnumpy())

            not_init_data = not init_param or _is_in_auto_parallel_mode() or _is_parallel_mode()
            if not_init_data:
                # do not init data while in auto parallel.
                return (Tensor, None, data.dtype, data.shape, data.init)
            return (Tensor, data.init_data())
        if isinstance(data, int):
            return (Tensor, data, mstype.int32)
        if isinstance(data, float):
            return (Tensor, data, mstype.float32)
        return (Tensor, data)

    def copy(self):
        """
        Copy the parameter.

        Returns:
            Parameter, a new parameter.

        Examples:
            >>> from mindspore import Tensor, Parameter
            >>> import numpy as np
            >>> x = Parameter(Tensor(np.array([1, 2], dtype=np.float32)), name="param")
            >>> y = x.copy()
        """
        return self.clone(init='same')

    @property
    def inited_param(self):
        """
        Get the new parameter after calling `init_data`.

        Defaults to ``None``. If `self` is a `Parameter` without data, the initialized `Parameter`
        with data produced by calling `init_data` is recorded here.

        Examples:
            >>> from mindspore import Tensor, Parameter
            >>> import numpy as np
            >>> x = Parameter(Tensor(np.array([1, 2], dtype=np.float32)), name="param")
            >>> x.inited_param
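            >>> # None here: the parameter was created from concrete data, so no
            >>> # separately initialized parameter has been recorded yet.
            >>> x.inited_param is None
            True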
        """
        return self._inited_param

    @property
    def param_info(self):
        """Get the `ParamInfo` instance bound to this parameter."""
        Tensor_.wait_pipeline(self)
        return self._param_info

    @param_info.setter
    def param_info(self, param_info_):
        Tensor_.wait_pipeline(self)
        param_info_.obj = self
        self._param_info = param_info_
        Tensor_.set_param_info(self, param_info_)

    @property
    def name(self):
        """
        Get the name of the parameter.

        Examples:
            >>> from mindspore import Tensor, Parameter
            >>> import numpy as np
            >>> x = Parameter(Tensor(np.array([1, 2], dtype=np.float32)), name="param")
            >>> x.name = "param1"
            >>> x.name
            'param1'
        """
        return self.param_info.name

    @name.setter
    def name(self, name_):
        """
        Define a name for the parameter.

        Args:
            name_ (`str` or `None`): The name of the parameter. When the name is ``None`` or an empty
                string, a unique name with the `PARAMETER_NAME_DEFAULT` prefix is generated automatically.
        """
        if name_ is None:
            name_ = _generate_parameter_name()
        elif isinstance(name_, str):
            name_ = name_.strip()
            if name_ == '':
                name_ = _generate_parameter_name()
            if len(name_) > PARAMETER_NAME_PREFIX_MAX_LEN:
                raise ValueError("The length of the parameter name '{}' must not exceed {}.".
                                 format(name_, PARAMETER_NAME_PREFIX_MAX_LEN))
        else:
            raise ValueError("The type of the Parameter's name should be 'string' or 'None', "
                             "but got {}.".format(type(name_)))

        self.param_info.name = name_

    @property
    def sliced(self):
        """
        Get slice status of the parameter.

        Examples:
            >>> from mindspore import Tensor, Parameter
            >>> import numpy as np
            >>> x = Parameter(Tensor(np.array([1, 2], dtype=np.float32)), name="param")
            >>> x.sliced = True
            >>> x.sliced
            True
        """
        return self._sliced

    @sliced.setter
    def sliced(self, sliced_):
        self._sliced = sliced_

    @property
    def comm_fusion(self):
        """
        Get the fusion type (int) for communication operators corresponding to this parameter.

        When using `AutoParallel(cell)` to enable parallel mode, some communication operators used
        for parameters or gradients aggregation are inserted automatically.
        The value of `comm_fusion` must be greater than or equal to 0.
        When the value of `comm_fusion` is ``0`` , operators will not be fused together.

        Examples:
            >>> from mindspore import Tensor, Parameter
            >>> import numpy as np
            >>> x = Parameter(Tensor(np.array([1, 2], dtype=np.float32)), name="param")
            >>> x.comm_fusion = 3
            >>> x.comm_fusion
            3
        """
        return self.param_info.comm_fusion

    @comm_fusion.setter
    def comm_fusion(self, comm_fusion_):
        Validator.check_non_negative_int(comm_fusion_)
        self.param_info.comm_fusion = comm_fusion_

    @property
    def parallel_optimizer_comm_recompute(self):
        """
        Get the communication recompute status (bool) of optimizer parallel for the parameter.

        When using `AutoParallel(cell)` to enable parallel mode and applying the parallel optimizer,
        some :class:`mindspore.ops.AllGather` operators
        used for parameter gathering are inserted automatically. This attribute controls the recompute
        attribute of those :class:`mindspore.ops.AllGather` operators.

        Note:
            - Only `Graph` mode is supported.
            - It is recommended to use cell.recompute(parallel_optimizer_comm_recompute=True/False) to configure
              the AllGather operators introduced by the parallel optimizer rather than using this interface directly.

        Examples:
            >>> from mindspore import Tensor, Parameter
            >>> import numpy as np
            >>> x = Parameter(Tensor(np.array([1, 2], dtype=np.float32)), name="param")
            >>> x.parallel_optimizer_comm_recompute = True
            >>> x.parallel_optimizer_comm_recompute
            True
        """
        return self.param_info.parallel_optimizer_comm_recompute

    @parallel_optimizer_comm_recompute.setter
    def parallel_optimizer_comm_recompute(self, parallel_optimizer_comm_recompute_):
        Validator.check_bool(parallel_optimizer_comm_recompute_)
        self.param_info.parallel_optimizer_comm_recompute = parallel_optimizer_comm_recompute_

    @property
    def unique(self):
        """
        Whether the parameter is already unique or not.

        Examples:
            >>> from mindspore import Tensor, Parameter
            >>> import numpy as np
            >>> x = Parameter(Tensor(np.array([1, 2], dtype=np.float32)), name="param")
            >>> x.unique = True
            >>> x.unique
            True
        """
        return self._unique

    @unique.setter
    def unique(self, unique_):
        self._unique = unique_

    def clone(self, init='same'):
        """
        Clone the parameter.

        Args:
            init (Union[Tensor, str, numbers.Number]): Initialize the shape and dtype of the parameter.
                If `init` is a `Tensor` or `numbers.Number`, clone a new parameter with the same shape
                and dtype, and the data of the new parameter will be set according to `init`. If `init`
                is a `str`, the `init` should be the alias of the class inheriting from `Initializer`.
                For example, if `init` is ``'same'``, clone a new parameter with the same data, shape, and
                dtype. Default: ``'same'``.

        Returns:
            Parameter, a new parameter.

        Examples:
            >>> from mindspore import Tensor, Parameter
            >>> import numpy as np
            >>> x = Parameter(Tensor(np.array([1, 2], dtype=np.float32)), name="param")
            >>> y = x.clone()
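            >>> # Clone with the 'zeros' initializer alias: same shape and dtype, zero-filled data.
            >>> z = x.clone(init='zeros')
            >>> print(z.asnumpy())
            [0. 0.]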
        """
        x = copy(self)
        param_info_clone = self.param_info.clone()
        info = self.param_info
        if hasattr(info, "cloned_obj"):
            info.cloned_obj.append(x)
        else:
            info.cloned_obj = [x]
        self.param_info = info
        param_info_clone.obj = x
        x.param_info = param_info_clone
        x.is_init = False
        x.init = self.init
        x.cache_enable = self.cache_enable
        if x.cache_enable:
            x.key = _get_unique_parameter_key()
        x.requires_aggr = self.requires_aggr
        if self.cache_shape:
            x.cache_shape = self.cache_shape
        if init != 'same':
            shape = self._shape
            dtype = self.dtype
            tensor = initializer(init, shape=shape, dtype=dtype)
            x.set_data(tensor)
            x.init = tensor.init
        device = self._get_user_data("parameter_device")
        if device is not None:
            x._set_user_data("parameter_device", device)
        return x

    @property
    def layerwise_parallel(self):
        """
        Get the layerwise parallel status (bool) of the parameter.

        When `layerwise_parallel` is ``True`` in `DATA_PARALLEL` or `HYBRID_PARALLEL` parallel mode,
        broadcast and gradient communication are not applied to the parameter.

        Examples:
            >>> from mindspore import Tensor, Parameter
            >>> import numpy as np
            >>> x = Parameter(Tensor(np.array([1, 2], dtype=np.float32)), name="param")
            >>> x.layerwise_parallel = True
            >>> x.layerwise_parallel
            True
        """
        return self.param_info.layerwise_parallel

    @layerwise_parallel.setter
    def layerwise_parallel(self, value=True):
        if not isinstance(value, bool):
            raise TypeError("The argument `layerwise_parallel` must be bool type.")
        self.param_info.layerwise_parallel = value

    @property
    def parallel_optimizer(self):
        """
        Get the optimizer parallel status (bool) of the parameter.

        When using `AutoParallel(cell)` to enable parallel mode, it is used to filter the weight
        shard operation. It works only when the parallel optimizer is enabled via
        :func:`mindspore.parallel.auto_parallel.AutoParallel.hsdp`.

        Examples:
            >>> from mindspore import Tensor, Parameter
            >>> import numpy as np
            >>> x = Parameter(Tensor(np.array([1, 2], dtype=np.float32)), name="param")
            >>> x.parallel_optimizer = True
            >>> x.parallel_optimizer
            True
        """
        return self.param_info.parallel_optimizer

    @parallel_optimizer.setter
    def parallel_optimizer(self, value=True):
        if not isinstance(value, bool):
            raise TypeError("The argument `parallel_optimizer` must be bool type.")
        self.param_info.parallel_optimizer = value

    @property
    def cache_enable(self):
        """
        Return whether cache is enabled for the parameter.

        Examples:
            >>> from mindspore import Tensor, Parameter
            >>> import numpy as np
            >>> x = Parameter(Tensor(np.array([1, 2], dtype=np.float32)), name="param")
            >>> x.cache_enable=True
            >>> x.cache_enable
            True
        """
        return self.param_info.cache_enable

    @cache_enable.setter
    def cache_enable(self, value=True):
        if not isinstance(value, bool):
            raise TypeError("The argument `cache_enable` must be bool type.")
        self.param_info.cache_enable = value

    @property
    def cache_shape(self):
        """
        Return the cache shape corresponding to the parameter when cache is enabled.

        Examples:
            >>> from mindspore import Tensor, Parameter
            >>> import numpy as np
            >>> x = Parameter(Tensor(np.array([1, 2], dtype=np.float32)), name="param")
            >>> x.cache_enable=True
            >>> x.cache_shape=[1, 2]
            >>> x.cache_shape
            [1, 2]
        """
        return self.param_info.cache_shape

    @cache_shape.setter
    def cache_shape(self, value):
        if not isinstance(value, (tuple, list)):
            raise TypeError("The argument `cache_shape` must be tuple or list type.")
        self.param_info.cache_shape = value

    @property
    def key(self):
        """
        Return the parameter unique key.

        Examples:
            >>> from mindspore import Tensor, Parameter
            >>> import numpy as np
            >>> x = Parameter(Tensor(np.array([1, 2], dtype=np.float32)), name="param")
            >>> x.key = 2
            >>> x.key
            2
        """
        return self.param_info.key

    @key.setter
    def key(self, value=-1):
        """Set the parameter unique key."""
        if not isinstance(value, int):
            raise TypeError("The argument `key` must be int type.")
        self.param_info.key = value

    @property
    def requires_grad(self):
        """
        Return whether the parameter requires gradient.

        Examples:
            >>> from mindspore import Tensor, Parameter
            >>> import numpy as np
            >>> x = Parameter(Tensor(np.array([1, 2], dtype=np.float32)), name="param")
            >>> x.requires_grad = True
            >>> x.requires_grad
            True
        """
        return self.param_info.requires_grad

    @requires_grad.setter
    def requires_grad(self, value=True):
        if not isinstance(value, bool):
            raise TypeError("The argument `requires_grad` must be bool type")
        Tensor_.wait_pipeline(self)
        self.param_info.requires_grad = value
        self._requires_grad = value

    def value(self):
        """
        Return the value of the parameter object.

        Examples:
            >>> from mindspore import Tensor, Parameter
            >>> import numpy as np
            >>> x = Parameter(Tensor(np.array([1, 2], dtype=np.float32)), name="param")
            >>> x_value = x.value()
            >>> print(x_value)
            [1.  2.]
        """
        return self.load(self, monad.U)

    def _update_tensor_data(self, data):
        """Update the parameter by a Tensor."""
        if isinstance(self, Tensor):
            self.init_flag = False
            self.init = None
            return self.assign_value(data)
        new_param = Parameter(data, self.name, self.requires_grad)
        new_param.param_info = self.param_info
        return new_param

    @_LogActionOnce(logger=logger, key='add_pipeline_stage')
    def add_pipeline_stage(self, stage):
        """
        Add a pipeline stage to the parameter.

        Args:
            stage (int): The pipeline stage to be added.

        Raises:
            TypeError: If `stage` is not a non-negative integer.
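
        Examples:
            >>> from mindspore import Tensor, Parameter
            >>> import numpy as np
            >>> x = Parameter(Tensor(np.array([1, 2], dtype=np.float32)), name="param")
            >>> x.add_pipeline_stage(0)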
        """
        logger.warning("This interface may be deleted in the future.")
        if not isinstance(stage, int) or stage < 0:
            raise TypeError("`stage` must be a non-negative integer.")
        self._pipeline_stage_list.append(stage)

    def _raise_type_error(self, incoming):
        raise TypeError(f"Incoming Parameter dtype can not be converted to current dtype implicitly. "
                        f"Current dtype is {self.dtype}, and incoming is {incoming}. "
                        f"Use .set_dtype(xxx) to change the dtype.")

    @staticmethod
    def _set_data_check_input_valid(current_tensor_is_init, incoming_tensor_is_init):
        if incoming_tensor_is_init and not current_tensor_is_init:
            raise TypeError("The argument 'data' is initialized, but the current parameter is not initialized. "
                            "Please initialize the parameter before calling this method.")

    @jit_forbidden_register
    def set_data(self, data, slice_shape=False):
        """
        Set Parameter's data.

        Args:
            data (Union[Tensor, int, float]): New data.
            slice_shape (bool): If `slice_shape` is set to ``True``, the shape consistency will not be checked.
                                Default: ``False``. When `slice_shape` is ``False`` and the shapes are not consistent,
                                a ValueError will be thrown.

        Returns:
            Parameter, the parameter after set data.

        Examples:
            >>> from mindspore import Tensor, Parameter
            >>> import numpy as np
            >>> x = Parameter(Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32)), name="param")
            >>> x.set_data(Tensor(np.array([[6, 6], [6, 6]], dtype=np.float32)))
            Parameter (name=param, shape=(2, 2), dtype=Float32, requires_grad=True)
        """
        if not isinstance(data, (Tensor, int, float)):
            raise TypeError(f"Parameter data must be a `Tensor` (or a `Tensor` subclass), an `int` or a `float`, "
                            f"but got type {type(data)}.")
        if isinstance(data, (int, float)):
            if self.dtype in mstype.int_type and isinstance(data, float):
                self._raise_type_error(mstype.float_)
            data = Tensor(data, self.dtype)
        # Whether the incoming data and the current parameter hold concrete (initialized) data.
        incoming_tensor_is_init = isinstance(data, Tensor) and not data.has_init
        current_tensor_is_init = isinstance(self, Tensor) and not self.has_init
        if self.dtype != data.dtype:
            from mindspore.ops import functional as F
            if isinstance(data, Tensor) and data.init is not None:
                data.init_data()
            data = F.cast(data, self.dtype)
        if isinstance(data, Tensor) and data.has_init:
            # The parameter has been initialized, directly update by the data
            if current_tensor_is_init:
                self._update_tensor_data(data.init_data())
            else:
                # also update the related inited parameter data
                if self.inited_param is not None:
                    self.inited_param.set_data(data)
                self.init_mode = data
        elif incoming_tensor_is_init or current_tensor_is_init:
            self._update_tensor_data(data)
        self.sliced = slice_shape
        return self

    @staticmethod
    def _get_init_data_args(layout=None):
        """Get the data layout args."""
        init_data_args = ()
        if layout:
            if not isinstance(layout, tuple):
                raise TypeError("The argument 'layout' should be tuple, but got {}.".format(type(layout)))
            if len(layout) < 6:
                raise ValueError("The length of 'layout' must be larger than 5, but got {}.".format(len(layout)))
            slice_index = int(_get_slice_index(layout[0], layout[1], layout[5]))
            init_data_args += (slice_index, layout[2], layout[5])
        return init_data_args

    def init_data(self, layout=None, set_sliced=False):
        """
        Initialize the parameter's data.

        Args:
            layout (Union[None, tuple]): The parameter's layout info.
                layout [dev_mat, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group].
                Default: ``None``.
                It's not None only when using `AutoParallel(cell)` to enable parallel mode.

                - dev_mat (list(int)): The parameter's device matrix.
                - tensor_map (list(int)): The parameter's tensor map.
                - slice_shape (list(int)): The parameter's slice shape.
                - field_size (int): The parameter's field size.
                - uniform_split (bool): Whether the parameter is split evenly.
                - opt_shard_group (str): The group of the parameter while running optimizer parallel.

            set_sliced (bool): True if the parameter is set sliced after initializing the data.
                Default: ``False``.

        Returns:
            Parameter, the `Parameter` after initializing data. If current `Parameter` was already initialized before,
            returns the same initialized `Parameter`.

        Raises:
            RuntimeError: If it is from Initializer, and parallel mode has changed after the Initializer was created.
            ValueError: If the length of the layout is less than 6.
            TypeError: If `layout` is not tuple.

        Examples:
            >>> from mindspore import Tensor, Parameter
            >>> import numpy as np
            >>> x = Parameter(Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32)), name="param")
            >>> x.init_data()
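            >>> import mindspore
            >>> from mindspore.common.initializer import initializer
            >>> # For a Parameter built from an initializer (lazy init), init_data materializes the data.
            >>> y = Parameter(initializer('ones', [1, 2], mindspore.float32), name="param_y")
            >>> y_data = y.init_data()
            >>> print(y_data.asnumpy())
            [[1. 1.]]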
        """
        if self.init_mode is None or not self.has_init:
            return self
        if self.inited_param is not None:
            return self.inited_param

        init_data_args = self._get_init_data_args(layout)

        data = self.init_mode.init_data(*init_data_args)
        origin_dtype = self.dtype
        obj = self._update_tensor_data(data)
        if self.dtype != origin_dtype:
            self.set_dtype(origin_dtype)
        if id(obj) != id(self):
            self._inited_param = obj
        obj.init_mode = None
        obj.sliced = set_sliced
        return obj

    def register_hook(self, hook_fn):
        """
        For details, please refer to :func:`mindspore.Tensor.register_hook`.
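
        A minimal usage sketch: the hook receives the gradient and may return a modified one.

        Examples:
            >>> from mindspore import Tensor, Parameter
            >>> import numpy as np
            >>> x = Parameter(Tensor(np.array([1.0], dtype=np.float32)), name="param")
            >>> handle = x.register_hook(lambda grad: grad * 2)
            >>> handle.remove()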
        """
        handle = Tensor.register_hook(self, hook_fn)
        _update_hook_version()
        return handle

    def _remove_hook(self):
        _update_hook_version()

    def _offload(self):
        r"""
        Offload parameter to host with releasing device resource.

        Supported Platforms:
            ``Ascend``

        .. warning::
            This is an experimental API that is subject to change or deletion.

        Examples:
            >>> import numpy as np
            >>> from mindspore import Parameter, Tensor
            >>> x = Parameter(Tensor(np.array([1, 2], dtype=np.float32)), name="param")
            >>> x._offload()
        """
        return Tensor_._offload(self, True)

    def _load(self):
        r"""
        Load parameter to device.

        Supported Platforms:
            ``Ascend``

        .. warning::
            This is an experimental API that is subject to change or deletion.

        Examples:
            >>> import numpy as np
            >>> from mindspore import Parameter, Tensor
            >>> x = Parameter(Tensor(np.array([1, 2], dtype=np.float32)), name="param")
            >>> x._offload()
            >>> x._load()
        """
        return Tensor_._load(self)


# Metaclass to combine _TensorMeta and the instance check override for Buffer.
class _BufferMeta(_TensorMeta):
    # Make `isinstance(t, Buffer)` return True for custom tensor instances that have the _is_buffer flag.
    def __instancecheck__(cls, instance):
        if cls is _Buffer:
            if isinstance(instance, Tensor) and getattr(instance, "_is_buffer", False):
                return True
        return super().__instancecheck__(instance)


class _Buffer(Tensor, metaclass=_BufferMeta):
    r"""A kind of Tensor that should not be considered a model
    parameter. For example, BatchNorm's `running_mean` is not a parameter, but is part of the Cell's state.

    Buffers are :class:`~mindspore.Tensor` subclasses, that have a
    very special property when used with :class:`~.nn.Cell` s: when they're
    assigned as Cell attributes they are automatically added to the list of
    its buffers, and will appear e.g. in :func:`mindspore.nn.Cell.buffers` iterator.
    Assigning a tensor doesn't have such an effect. One can still assign a tensor as a buffer explicitly by using
    the :func:`mindspore.nn.Cell.register_buffer` function.

    Args:
        data (Tensor): buffer tensor.

    Keyword Args:
        persistent (bool, optional): whether the buffer is part of the Cell's
            :attr:`state_dict`. Default ``True``.
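
    Examples:
        A minimal sketch of creating a buffer directly; buffers are normally created via
        :func:`mindspore.nn.Cell.register_buffer`:

        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore.common.parameter import _Buffer
        >>> buf = _Buffer(Tensor(np.zeros((2, 2), dtype=np.float32)))
        >>> isinstance(buf, _Buffer)
        True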
    """

    def __new__(cls, data, *, persistent=True):
        if data is None:
            raise ValueError('To create a Buffer, the input data should not be None')
        if not isinstance(data, Tensor):
            raise TypeError('To create a Buffer, the type of the input data should be Tensor')
        from mindspore.ops import stop_gradient
        t = stop_gradient(data)
        t._is_buffer = True  # pylint: disable=W0212
        t.persistent = persistent
        return t


class ParameterTuple(tuple):
    """
    Inherited from tuple, ParameterTuple is used to save multiple parameters.

    Note:
        It is used to store the parameters of the network into the parameter tuple collection.

    Examples:
        >>> from mindspore import Tensor, Parameter, ParameterTuple
        >>> import numpy as np
        >>> x = Parameter(Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32)), name="param")
        >>> y = Parameter(Tensor(np.array([[5, 6], [7, 8]], dtype=np.float32)), name="param1")
        >>> pt = ParameterTuple([x, y])
        >>> pt1 = pt.clone(prefix="new")
    """

    def __new__(cls, iterable):
        """Create instance object of ParameterTuple."""
        data = tuple(iterable)
        ids = set()
        names = set()
        for x in data:
            if not isinstance(x, Parameter):
                raise TypeError(f"For ParameterTuple initialization, "
                                f"ParameterTuple input should be 'Parameter' collection, "
                                f"but got a {type(iterable)}. ")
            if id(x) not in ids:
                if x.name in names:
                    raise ValueError("The value {} , its name '{}' already exists. "
                                     "Please set a unique name for the parameter.".format(x, x.name))
                names.add(x.name)
                ids.add(id(x))
        return tuple.__new__(ParameterTuple, tuple(data))

    def clone(self, prefix, init='same'):
        """
        Clone the parameters in the ParameterTuple element-wise to generate a new ParameterTuple.

        Args:
            prefix (str): Namespace of parameter, the prefix string will be added to the names of parameters
                in the ParameterTuple.

            init (Union[Tensor, str, numbers.Number]): Clone the shape and dtype of Parameters in ParameterTuple and
                set data according to `init`. Default: ``'same'``.

                - If `init` is a `Tensor` , set the new Parameter data to the input Tensor.
                - If `init` is a `str`, data will be set according to the initialization method of the same name in
                  the `Initializer`. When it is ``'same'``, the new Parameter will have the same value
                  with the original Parameter.
                - If `init` is `numbers.Number` , set the new Parameter data to the input number.

        Returns:
            Tuple, the new Parameter tuple.
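
        Examples:
            >>> from mindspore import Tensor, Parameter, ParameterTuple
            >>> import numpy as np
            >>> pt = ParameterTuple((Parameter(Tensor(np.ones(2, dtype=np.float32)), name="w"),))
            >>> new_pt = pt.clone(prefix="backup")
            >>> new_pt[0].name
            'backup.w'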
        """
        Validator.check_str_by_regular(prefix)
        new = []
        for x in self:
            x1 = x.clone(init)
            x1.name = prefix + "." + x1.name
            new.append(x1)

        return ParameterTuple(new)

    def __parameter_tuple__(self):
        """For parse check."""
